import logging
from collections import defaultdict
from dataclasses import dataclass
from pickle import INST
from re import S
from typing import Optional, List

import itertools
import random
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerBase

from .. import gist
from difflib import SequenceMatcher
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logger = logging.getLogger(__name__)


def get_prepend(method, icl_dataset, tokenizer, update_tokens, token_name_dict, target, num_demonstrations=1, idx_dict=None, use_scalar_encode=False):
    """Build the in-context demonstration text to prepend to ``target``.

    Each demonstration is rendered from its ``"formulation"`` template: every
    entry of ``update_tokens`` is replaced by ``token_name_dict[entry]``, the
    ``<input>`` (or ``<input 0>``, ``<input 1>``, ...) placeholders by the
    instance's input(s), and ``<output>`` by ``str(instance["output"])``.

    Args:
        method: How demonstrations are chosen. ``"rand"`` samples uniformly
            from ``icl_dataset``; ``"knn"`` takes
            ``icl_dataset[idx_dict[target["idx"]][i]]`` for the i-th
            demonstration; ``"debug"`` reuses ``target`` itself.
        icl_dataset: Indexable, sized collection of instance dicts. Not
            accessed when ``method == "debug"``.
        tokenizer: Callable returning a mapping with an ``"input_ids"``
            sequence for a given string.
        update_tokens: Placeholder strings to substitute in the template.
        token_name_dict: Maps each placeholder to its replacement text.
        target: The instance the demonstrations are built for.
        num_demonstrations: Number of demonstrations to render.
        idx_dict: Neighbour lookup table; required for ``"knn"``.
        use_scalar_encode: Unused by this function; kept for call-site
            compatibility.

    Returns:
        A tuple ``(prep, prompt_len, icl_labels)``: the concatenated
        demonstrations (each followed by ``" \\n\\n"``), the token count minus
        one of each recorded input, and each demonstration's output as float.

    Raises:
        ValueError: If ``method`` is not one of the supported values, or if
            ``method == "knn"`` and ``idx_dict`` is None.
    """
    rendered = []
    prompt_len = []
    icl_labels = []

    for demo_idx in range(num_demonstrations):
        if method == "rand":
            instance = icl_dataset[random.randint(0, len(icl_dataset) - 1)]
        elif method == "knn":
            if idx_dict is None:
                raise ValueError("method 'knn' requires idx_dict")
            instance = icl_dataset[idx_dict[target["idx"]][demo_idx]]
        elif method == "debug":
            instance = target
        else:
            # Previously this fell through and failed later with an unbound
            # local; fail here with a message that names the bad value.
            raise ValueError(
                f"Unknown demonstration selection method: {method!r} "
                "(expected 'rand', 'knn' or 'debug')"
            )

        text = instance["formulation"]
        for prompt_token in update_tokens:
            text = text.replace(prompt_token, token_name_dict[prompt_token])

        inputs = instance["input"]
        if isinstance(inputs, list):
            for ip_idx, ip in enumerate(inputs):
                text = text.replace("<input " + str(ip_idx) + ">", ip)
                # Only the first two inputs contribute a recorded length.
                if ip_idx < 2:
                    prompt_len.append(len(tokenizer(ip)["input_ids"]) - 1)
        else:
            text = text.replace("<input>", inputs)
            prompt_len.append(len(tokenizer(inputs)["input_ids"]) - 1)
        text = text.replace("<output>", str(instance["output"]))

        rendered.append(text + " \n\n")
        icl_labels.append(float(instance["output"]))

    return "".join(rendered), prompt_len, icl_labels



@dataclass
class DataCollatorForAlpacaCLM:
    """Data collator for decoder-only models. Does left padding.

    Each batch element is a dict with (at least) the keys ``"formulation"``,
    ``"input"``, ``"output"``, ``"task"`` and ``"regression"`` (plus
    ``"regression_dim"`` for regression examples). The template in
    ``"formulation"`` is filled in, optionally prefixed with in-context
    demonstrations, tokenized, truncated or dropped if too long, and finally
    left-padded into tensors.
    """

    # Callable tokenizer; also supplies eos_token_id and pad_token_id.
    tokenizer: PreTrainedTokenizerBase
    # Source of in-context demonstrations; None disables demonstrations.
    icl_dataset: Optional[object]
    # Demonstration selection method forwarded to get_prepend.
    method: Optional[str]
    # Number of demonstrations forwarded to get_prepend.
    num_demonstrations: Optional[int]
    # Neighbour lookup table forwarded to get_prepend.
    idx_dict: Optional[dict]
    # Maximum number of tokens per example; None means no limit.
    max_length: Optional[int] = None
    # Value written at label positions that must not contribute to the loss.
    label_pad_token_id: int = -100
    # Not consulted by __call__; kept for interface compatibility.
    return_tensors: str = "pt"
    # Not consulted by the code in this class.
    pad_token: int = 0
    # Token ids >= this value are masked out of the labels.
    pretrained_vocab_size: int = 32000
    # Maps each entry of update_tokens to the text that replaces it.
    token_name_dict: Optional[dict] = None
    # Values matched against token ids to find where each prompt starts.
    # NOTE(review): annotated as strings but compared with token ids - confirm.
    start_markers: Optional[List[str]] = None
    # Placeholders substituted in the template; the last one is tokenized and
    # its final token id is used to locate the regression read-out position.
    update_tokens: Optional[List[str]] = None
    # Not consulted by the code in this class.
    check_correctness: bool = False
    # When True, over-length examples are dropped instead of truncated.
    eval_mode: bool = False
    # Number of positions per prompt marker covered by the gist mask span.
    num_token_per_prompt: int = 10
    # When True, every kept example's float output is appended to icl_labels.
    use_scalar_encode: bool = False
    # When True, the gist mask span counts num_token_per_prompt twice.
    use_end_marker: bool = False
    # When False, the regression read-out position is fixed at -2.
    use_functional_token: bool = True
    # When True, regression examples also get token-level labels.
    add_ce_loss: bool = False
    # When True, no "attention_mask_gist" entry is produced.
    autoregressive_attn_mask: bool = False

    def _fill_template(self, instance):
        """Render one instance's template; return ``(text, input_lens)``.

        ``input_lens`` holds, for the single input (or the first two inputs of
        a list), the tokenized length minus one.
        """
        text = instance["formulation"]
        for prompt_token in self.update_tokens:
            text = text.replace(prompt_token, self.token_name_dict[prompt_token])

        input_lens = []
        inputs = instance["input"]
        if isinstance(inputs, list):
            for ip_idx, ip in enumerate(inputs):
                text = text.replace("<input " + str(ip_idx) + ">", ip)
                if ip_idx < 2:
                    input_lens.append(len(self.tokenizer(ip)["input_ids"]) - 1)
        else:
            text = text.replace("<input>", inputs)
            input_lens.append(len(self.tokenizer(inputs)["input_ids"]) - 1)

        text = text.replace("<output>", str(instance["output"]))
        return text, input_lens

    def __call__(self, batch, return_tensors=None):
        """Collate ``batch`` into left-padded model inputs.

        Args:
            batch: Sequence of instance dicts (see the class docstring).
            return_tensors: Accepted for interface compatibility; the result
                always contains torch tensors.

        Returns:
            A dict with ``"input_ids"``, ``"labels"`` and ``"attention_mask"``
            tensors, plus, when applicable: ``"label_reg"``, ``"reg_idx"``,
            ``"reg_dim"``, ``"reg_pred_idx"``, ``"clf_idx"``, ``"icl_labels"``
            and ``"attention_mask_gist"``.
        """
        max_length = self.max_length
        ignore_id = self.label_pad_token_id

        model_inputs = defaultdict(list)

        reg_idx = []
        reg_pred_idx = None
        reg_dim = []
        clf_idx = []

        prompt_idx = []
        prompt_len = []
        ori_len = []
        icl_labels = []

        num_passed = 0

        tokenized_function_token = self.tokenizer(self.token_name_dict[self.update_tokens[-1]])["input_ids"][1:]

        for idx, instance in enumerate(batch):
            text, input_lens = self._fill_template(instance)

            demo_labels = []
            if self.icl_dataset is not None:
                prep, preplens, demo_labels = get_prepend(self.method, self.icl_dataset, self.tokenizer, self.update_tokens, self.token_name_dict, instance, self.num_demonstrations, self.idx_dict, self.use_scalar_encode)
                text = prep + text
                input_lens = preplens + input_lens

            tokenized_input = self.tokenizer(text)["input_ids"] + [self.tokenizer.eos_token_id]
            if instance["task"] == "Generation":
                # Generation examples drop the first token and the appended EOS.
                tokenized_input = tokenized_input[1:-1]

            is_regression = instance["regression"]
            wants_token_labels = not is_regression or self.add_ce_loss

            if is_regression:
                label_reg = float(instance["output"])

            if wants_token_labels:
                # Supervise every token except added-vocabulary ids and id 0.
                token_arr = np.array(tokenized_input)
                is_special = (token_arr >= self.pretrained_vocab_size) | (token_arr == 0)
                labels = np.where(is_special, ignore_id, token_arr).tolist()
            else:
                labels = [ignore_id] * len(tokenized_input)

            # max_length=None (the default) means "no limit"; comparing an int
            # with None would raise TypeError.
            to_trim = 0 if max_length is None else len(tokenized_input) - max_length
            if to_trim > 0:
                # The last example is always kept (truncated) if nothing else
                # survived, so the batch is never empty.
                must_keep = num_passed == 0 and idx == len(batch) - 1
                if not must_keep and (is_regression or self.eval_mode):
                    continue
                labels = labels[:-to_trim]
                tokenized_input = tokenized_input[:-to_trim]

            # Per-example bookkeeping is recorded only once the example is
            # known to be kept, so dropped examples leave no stray
            # demonstration labels or prompt lengths behind.
            prompt_len.append(input_lens)
            icl_labels.extend(demo_labels)

            model_inputs["input_ids"].append(tokenized_input)
            model_inputs["labels"].append(labels)
            model_inputs["attention_mask"].append([1] * len(tokenized_input))

            if is_regression:
                if not self.use_functional_token:
                    reg_pred_idx = -2
                else:
                    try:
                        # Position of the last function token, counted from
                        # the end of the sequence (negative index).
                        hits = np.argwhere(np.array(tokenized_input) == tokenized_function_token[-1]).flatten()
                        reg_pred_idx = hits[-1] - len(tokenized_input)
                    except IndexError:
                        # No function token ids, or none found in the input.
                        reg_pred_idx = -2
                reg_idx.append(num_passed)
                reg_dim.append(instance["regression_dim"])
                model_inputs["label_reg"].append(label_reg)

            if wants_token_labels:
                clf_idx.append(num_passed)

            if self.use_scalar_encode:
                icl_labels.append(float(instance["output"]))

            ori_len.append(len(tokenized_input))
            prompt_idx.append(np.argwhere(np.isin(np.array(tokenized_input), self.start_markers)).flatten())
            num_passed += 1

        # Left-pad inputs, convert to tensor
        for key in list(model_inputs):
            value = model_inputs[key]
            if key == "label_reg":
                model_inputs[key] = torch.tensor(value).float()
                continue

            if key == "labels":
                pad_value = ignore_id
            elif key == "attention_mask":
                # Padding positions must be 0 regardless of the pad token id.
                pad_value = 0
            else:
                pad_value = self.tokenizer.pad_token_id

            # To left-pad inputs, reverse, then right-pad, then reverse
            reversed_rows = [torch.tensor(v[::-1]) for v in value]
            model_inputs[key] = torch.fliplr(
                pad_sequence(
                    reversed_rows,
                    batch_first=True,
                    padding_value=pad_value,
                )
            )

        # mask for domain markers
        if not self.autoregressive_attn_mask:
            bs, max_len = model_inputs["input_ids"].shape
            attn_mask = torch.zeros(bs, max_len, max_len)
            span_base = (1 + int(self.use_end_marker)) * self.num_token_per_prompt

            for i, markers in enumerate(prompt_idx):
                if len(markers) == 0:
                    continue

                # Marker positions were found before padding; shift them by
                # the amount of left padding added to this row.
                offset = max_len - ori_len[i]

                k = 0
                for pos in markers:
                    x = int(pos) + offset
                    if x == 0:
                        continue
                    # Rows [x, stop) may attend to every column before x.
                    stop = min(x + span_base + prompt_len[i][k], max_len)
                    if stop > x:
                        attn_mask[i, x:stop, :x] = 1
                    k += 1

            model_inputs["attention_mask_gist"] = attn_mask.unsqueeze(1)

        if len(reg_idx) > 0:
            model_inputs["reg_idx"] = reg_idx
            model_inputs["reg_dim"] = reg_dim
            model_inputs["reg_pred_idx"] = reg_pred_idx
        if len(clf_idx) > 0:
            model_inputs["clf_idx"] = clf_idx
        if len(icl_labels) > 0:
            model_inputs["icl_labels"] = torch.tensor(icl_labels).float()

        return model_inputs


